import time
from tqdm import tqdm
import json
import jsonlines
import numpy as np
from fuzzywuzzy import fuzz
import spacy
import re

def parse_answers(texts, num_steps=4):
    """Tally per-step "which is better?" verdicts and average them over texts.

    Each text is scanned (case-insensitively) for every occurrence of
    ``which is better?`` followed by a sentence running up to and including
    the next period. The k-th such sentence in a text is the verdict for
    step k. A verdict containing "option a" counts toward
    ``step{k}_optiona``; otherwise one containing "option b" counts toward
    ``step{k}_optionb``; anything else counts toward ``step{k}_equal``.

    Args:
        texts: Iterable of response strings.
        num_steps: Number of steps to track (default 4). Verdicts beyond
            this many in a single text are ignored rather than raising.

    Returns:
        Dict mapping ``step{k}_optiona`` / ``step{k}_optionb`` /
        ``step{k}_equal`` for k in 1..num_steps to the tally divided by the
        number of texts. All values are 0.0 when ``texts`` is empty.
    """
    pattern = re.compile(r'which is better\?\s*([^.]*\.)', re.IGNORECASE)
    choices = ('optiona', 'optionb', 'equal')

    # Pre-seed every key so the output schema is fixed and ordered by step.
    aggregated_results = {
        f'step{step}_{choice}': 0.0
        for step in range(1, num_steps + 1)
        for choice in choices
    }

    texts = list(texts)  # accept any iterable; the count is needed below
    for text in texts:
        # Only the first num_steps verdicts have a slot; extras are dropped
        # instead of indexing a key that does not exist.
        for i, match in enumerate(pattern.findall(text)[:num_steps]):
            verdict = match.lower()
            step_key = f'step{i + 1}_'
            # NOTE: plain substring tests checked in order, so a verdict that
            # names both options (e.g. "option a and option b are equal") is
            # credited to option A.
            if 'option a' in verdict:
                aggregated_results[step_key + 'optiona'] += 1
            elif 'option b' in verdict:
                aggregated_results[step_key + 'optionb'] += 1
            else:
                aggregated_results[step_key + 'equal'] += 1

    num_texts = len(texts)
    if num_texts == 0:
        # Nothing to average; return the all-zero tallies instead of dividing by zero.
        return aggregated_results
    return {key: value / num_texts for key, value in aggregated_results.items()}


if __name__ == "__main__":
    # Score the GPT-4 rationale ratings in one JSONL split and print the
    # per-step average verdict frequencies.
    input_path = './gpt4_ans/winogavil/anscot/swow/test.jsonl'
    texts = []
    # The reader is used as a context manager so its file handle is closed.
    with jsonlines.open(input_path, mode='r') as dataset:
        # NOTE(review): total only sizes the progress bar; recorded counts are
        # 5_6: 260, 10_12: 85, swow: 84 — 51 may be stale for this path.
        with tqdm(desc='Process', unit='it', total=51) as pbar:
            for line in dataset.iter():
                texts.append(line["gpt4_rate"])
                pbar.update()

    average_results = parse_answers(texts)
    for key, value in average_results.items():
        print(f'{key} = {value:.2f}')